package org.zjvis.datascience.common.util.sqlParse;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLUnionQuery;
import com.alibaba.druid.sql.parser.ParserException;
import com.alibaba.druid.sql.parser.SQLExprParser;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.util.JdbcConstants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.dto.SqlQueryDTO;

import java.util.List;
import java.util.Map;

/**
 * General-purpose SQL parsing helpers, built on Alibaba Druid's PostgreSQL dialect parser.
 *
 * @date 2021-09-01
 */
public class ParseUtil {

    private final static Logger logger = LoggerFactory.getLogger("ParseUtil");

    /** Error message returned when the submitted SQL contains no statement at all. */
    private static final String EMPTY_SQL_MESSAGE = "SQL语句为空";

    /**
     * Parses a single SQL query into a syntax tree, lets {@link MySqlVisitor} rewrite the
     * table references, and renders the result back to PostgreSQL text.
     *
     * @param schema    schema handed to {@link MySqlVisitor}; presumably the schema the rewritten
     *                  table names live in — confirm against MySqlVisitor
     * @param sql       the SQL text to parse
     * @param tablesMap mapping from user table names to GP table names
     * @return a DTO holding the rewritten SQL on success. Otherwise a DTO with code 400 and an
     *         error message when the SQL is empty, fails to parse, contains more than one
     *         statement, is not a SELECT, or references tables the visitor does not know.
     */
    public static SqlQueryDTO sqlParse(String schema, String sql, Map<String, String> tablesMap) {
        // Reject null/blank input up front: Druid would otherwise fail with an unhelpful
        // exception (null message) or yield no statements at all.
        if (sql == null || sql.trim().isEmpty()) {
            return new SqlQueryDTO(400, EMPTY_SQL_MESSAGE);
        }
        List<SQLStatement> sqlStatements;
        try {
            sqlStatements = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
        } catch (ParserException e) {
            logger.error("something error happened when parsing SQL. caused by {}", e.getMessage());
            return new SqlQueryDTO(400, "语法错误，请检查SQL语句");
        } catch (Exception e) {
            // Unexpected failure: keep returning the message to the caller as before, but
            // record the stack trace instead of silently dropping it.
            logger.error("unexpected error happened when parsing SQL", e);
            return new SqlQueryDTO(400, e.getMessage());
        }
        // Input made only of comments/separators may parse to zero statements;
        // get(0) below would otherwise throw IndexOutOfBoundsException.
        if (sqlStatements == null || sqlStatements.isEmpty()) {
            return new SqlQueryDTO(400, EMPTY_SQL_MESSAGE);
        }
        if (sqlStatements.size() > 1) {
            return new SqlQueryDTO(400, "有多条SQL语句");
        }
        SQLStatement sqlStatement = sqlStatements.get(0);
        if (!(sqlStatement instanceof SQLSelectStatement)) {
            return new SqlQueryDTO(400, "非SQL查询语句");
        }
        SQLSelectStatement sqlSelectStatement = (SQLSelectStatement) sqlStatement;
        // The visitor walks the tree, rewriting table references via tablesMap/schema and
        // recording any table names it could not resolve.
        MySqlVisitor visitor = new MySqlVisitor(tablesMap, schema);
        sqlSelectStatement.accept(visitor);
        if (!visitor.isUserTableHit()) {
            return new SqlQueryDTO(400, "数据表 " + visitor.getUnknownTables() + " 不存在");
        }
        return new SqlQueryDTO(SQLUtils.toPGString(sqlSelectStatement));
    }


    /**
     * Caps the number of rows a SELECT statement may return.
     *
     * <p>If the query has no LIMIT, {@code LIMIT limitNumber} is added. If it has a LIMIT whose
     * row count parses as an {@code int} no greater than {@code limitNumber}, it is left alone.
     * Otherwise (a larger value, or a row count that is absent or does not parse as an int)
     * the row count is replaced by {@code limitNumber}. For set operations (UNION etc.) the
     * limit is applied to the combined result.
     *
     * @param sql         SQL text; only its first statement is used and it must be a SELECT
     * @param limitNumber maximum number of rows to allow
     * @return the rewritten statement rendered as PostgreSQL text
     * @throws ParserException          if Druid cannot parse {@code sql}
     * @throws ClassCastException       if the first statement is not a SELECT
     * @throws IllegalArgumentException if the SELECT has a query shape that cannot take a LIMIT here
     */
    public static String addLimit(String sql, int limitNumber) {
        List<SQLStatement> sqlStatements = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
        SQLSelectStatement sqlSelectStatement = (SQLSelectStatement) sqlStatements.get(0);
        SQLSelectQuery query = sqlSelectStatement.getSelect().getQuery();
        if (query instanceof SQLUnionQuery) {
            // getQueryBlock() returns null for set operations, so the limit has to be attached
            // to the union node itself rather than to a single query block.
            SQLUnionQuery unionQuery = (SQLUnionQuery) query;
            if (unionQuery.getLimit() == null) {
                unionQuery.setLimit(newLimit(limitNumber));
            } else {
                capRowCount(unionQuery.getLimit(), limitNumber);
            }
        } else {
            SQLSelectQueryBlock queryBlock = sqlSelectStatement.getSelect().getQueryBlock();
            if (queryBlock == null) {
                throw new IllegalArgumentException("cannot add LIMIT to query of type "
                        + (query == null ? "null" : query.getClass().getSimpleName()));
            }
            if (queryBlock.getLimit() == null) {
                queryBlock.setLimit(newLimit(limitNumber));
            } else {
                capRowCount(queryBlock.getLimit(), limitNumber);
            }
        }
        return SQLUtils.toPGString(sqlSelectStatement);
    }

    /** Builds a fresh {@code LIMIT limitNumber} node using the PostgreSQL expression parser. */
    private static SQLLimit newLimit(int limitNumber) {
        SQLExprParser sqlParser = SQLParserUtils
                .createExprParser("limit " + limitNumber, JdbcConstants.POSTGRESQL);
        return sqlParser.parseLimit();
    }

    /**
     * Lowers an existing LIMIT's row count to {@code limitNumber} unless it is already a
     * parseable int no greater than the cap.
     */
    private static void capRowCount(SQLLimit sqlLimit, int limitNumber) {
        SQLExpr rowCount = sqlLimit.getRowCount();
        if (rowCount != null) {
            try {
                if (Integer.parseInt(rowCount.toString()) <= limitNumber) {
                    return;
                }
            } catch (NumberFormatException ignored) {
                // Not a plain integer (e.g. a parameter or expression): replace it below.
            }
        }
        sqlLimit.setRowCount(limitNumber);
    }
}
